#' README
#' Summarizes the state-level predictions.
#' The output file twostep_pred.csv holds the results of the joint model for 44 states.
#' The output file ind_pred.csv holds the results of the stand-alone model, which is used for US.AK US.HI US.ME US.MT US.ND US.SD US.VT.
#' The output file relative_mse.csv shows the MSE of our ARGOX relative to the MSE of the naive method.

options(echo=TRUE)
library(xts)
library(glmnet)
library(argo)
library(parallel)
library(data.table)
library(boot)

# History end months; one saved workspace (.rda) is read per entry.
gt_end_months <- c(
  "2014-10", "2015-04", "2015-10", "2016-04", "2016-10", "2017-04",
  "2017-10", "2018-04", "2018-10", "2019-04", "2019-10", "2020-03"
)

# Load each workspace stored under argo2_mix0.33/ into its own environment,
# keyed by the history end month.
env_list <- list()
for (end_month in gt_end_months) {
  run_env <- new.env()
  load(paste0("argo2_mix0.33/argo2-state-GT-pub_history_ending_", end_month, ".rda"), envir = run_env)
  env_list[[end_month]] <- run_env
}
# Keep a handle on this set of environments under a dedicated name.
env_list_mixed033 <- env_list

# Seed the prediction table with the two-step output of the earliest
# workspace, then merge in the national series of the latest workspace
# (presumably to extend the time index up to the most recent dates) and drop
# the helper column again.
# NOTE(review): the name of the dropped column looks like it is derived from
# the text of the merge() argument, so that expression is left untouched --
# confirm before rewording it.
twostep_pred <- env_list$`2014-10`$argo2_result_sub$twostep
twostep_pred <- merge(twostep_pred, env_list$`2020-03`$ili_national)
twostep_pred$env_list..2020.03..ili_national <- NULL

# Walk through the later workspaces in order; the non-missing rows of each
# one overwrite whatever is already stored for the same dates.
for (end_month in gt_end_months[-1]) {
  latest <- na.omit(env_list[[end_month]]$argo2_result_sub$twostep)
  latest_dates <- index(latest)
  for (state_col in colnames(latest)) {
    twostep_pred[latest_dates, state_col] <- data.matrix(latest[, state_col])
  }
}

# Drop rows that still contain missing values and write the table with its
# columns in alphabetical order.
twostep_pred <- na.omit(twostep_pred)
sorted_cols <- sort(colnames(twostep_pred))
write.csv(data.frame(twostep_pred[, sorted_cols]), "twostep_pred.csv")

# Replace every entry of env_list with a fresh environment holding the
# workspace stored under argo2_mix0/ for the same history end month.
for (end_month in gt_end_months) {
  run_env <- new.env()
  load(paste0("argo2_mix0/argo2-state-GT-pub_history_ending_", end_month, ".rda"), envir = run_env)
  env_list[[end_month]] <- run_env
}
# Keep a handle on this set of environments under a dedicated name.
env_list_purestate <- env_list

# Percentage of entries equal to zero in each element of GT_state taken from
# the 2020-03 workspace. vapply() guarantees one numeric value per element.
zero_pct <- vapply(
  env_list_purestate$`2020-03`$GT_state,
  function(x) mean(x == 0),
  numeric(1)
) * 100

# Format as fixed two-decimal strings with a trailing percent sign. The
# formatting is applied to the whole one-column matrix at once so that all
# entries share a common width.
outtable <- format(round(as.matrix(zero_pct), 2), nsmall = 2)
outtable[] <- paste0(outtable, "%")

# LaTeX table for the "US" entry on its own.
xtable::xtable(t(outtable["US", ]))

# Remaining entries: alphabetical order, with US.NY.NYC moved to the end.
outtable <- outtable[setdiff(rownames(outtable), "US"), ]
outtable <- outtable[order(names(outtable))]
outtable <- c(outtable[setdiff(names(outtable), "US.NY.NYC")], outtable["US.NY.NYC"])

# Print the entries as LaTeX tables of at most 10 columns each. The chunks
# are derived from the actual length, so the last table never indexes past
# the end of the vector (which would add NA-named, NA-valued columns).
col_chunks <- split(seq_along(outtable), ceiling(seq_along(outtable) / 10))
for (chunk in col_chunks) {
  print(xtable::xtable(t(outtable[chunk])), include.rownames = FALSE)
}

# Unrounded percentages laid out in 7 columns; pad with NA up to the next
# multiple of 7 so the matrix is filled without recycling.
n_pad <- (-length(zero_pct)) %% 7
matrix(c(zero_pct, rep(NA, n_pad)), ncol = 7)

# Seed the stand-alone prediction table with the output of the earliest
# workspace, then merge in the national series of the latest workspace
# (presumably to extend the time index up to the most recent dates) and drop
# the helper column again.
# NOTE(review): the name of the dropped column looks like it is derived from
# the text of the merge() argument, so that expression is left untouched --
# confirm before rewording it.
ind_pred <- env_list$`2014-10`$argo_ind_result$twostep
ind_pred <- merge(ind_pred, env_list$`2020-03`$ili_national)
ind_pred$env_list..2020.03..ili_national <- NULL

# Walk through the later workspaces in order; the non-missing rows of each
# one overwrite whatever is already stored for the same dates.
for (end_month in gt_end_months[-1]) {
  latest <- na.omit(env_list[[end_month]]$argo_ind_result$twostep)
  latest_dates <- index(latest)
  for (state_col in colnames(latest)) {
    ind_pred[latest_dates, state_col] <- data.matrix(latest[, state_col])
  }
}

# Drop rows that still contain missing values and write the table with its
# columns in alphabetical order.
ind_pred <- na.omit(ind_pred)
sorted_cols <- sort(colnames(ind_pred))
write.csv(data.frame(ind_pred[, sorted_cols]), "ind_pred.csv")

# Pull the reference objects out of the last workspace in env_list.
latest_env <- env_list[[length(env_list)]]
ili_state <- latest_env$ili_state
state_names <- latest_env$state_names
state_names_sub <- latest_env$state_names_sub
pred_state <- latest_env$pred_state

# Naive forecast: the observed series with its time index moved forward by 7,
# so each value is lined up with a later time stamp.
# NOTE(review): presumably the index is a Date and 7 means one week -- confirm.
naive.pred <- ili_state
index(naive.pred) <- index(naive.pred) + 7

# Combined predictions: start from the stand-alone table and overwrite the
# columns listed in state_names_sub with the joint-model predictions.
pred_comb <- ind_pred
pred_comb[, state_names_sub] <- twostep_pred[, state_names_sub]
# NOTE(review): no statement in this script appears to create a column with
# this name, so this removal may be a leftover -- confirm before deleting it.
pred_comb$env_list..length.env_list....ili_national <- NULL

# One page per column: combined prediction plotted together with the
# observed series of the same name.
pdf("state_pred.pdf", width = 12)
for (state_col in colnames(pred_comb)) {
  overlay <- cbind(pred_comb[, state_col], ili_state[, state_col])
  print(plot.xts(overlay, main = state_col))
}
dev.off()

# Mean squared errors over the evaluation window, one value per column.
# Arithmetic between the prediction tables and ili_state pairs columns by
# position, so the same columns are selected by name on both sides to make
# sure every prediction is compared with the series of the same name.
eval_window <- "2014-10-11/"
err_argo2 <- colMeans(
  (twostep_pred[, state_names_sub] - ili_state[index(twostep_pred), state_names_sub])[eval_window]^2
)
err_ind <- colMeans(
  (ind_pred[, state_names] - ili_state[, state_names])[eval_window]^2
)
err_naive <- colMeans(
  (naive.pred[, state_names] - ili_state[, state_names])[eval_window]^2
)

# Collect the three error vectors in one table, matched by row name. Rows
# without a joint-model error keep NA in err_argo2.
err_df <- data.frame(err_ind)
err_df$err_argo2 <- NA_real_
err_df[names(err_argo2), "err_argo2"] <- err_argo2
err_df$err_naive <- NA_real_
err_df[names(err_naive), "err_naive"] <- err_naive

# Errors relative to the naive forecast. The column "err" takes the
# joint-model ratio where the name is in state_names_sub and the stand-alone
# ratio for the remaining names.
relative_mse <- err_df
relative_mse$err_ind <- relative_mse$err_ind / relative_mse$err_naive
relative_mse$err_argo2 <- relative_mse$err_argo2 / relative_mse$err_naive
relative_mse$err <- relative_mse$err_argo2
ind_only <- setdiff(state_names, state_names_sub)
relative_mse[ind_only, "err"] <- relative_mse[ind_only, "err_ind"]
relative_mse <- relative_mse[, c("err_argo2", "err_ind", "err")]

# Bar chart of the relative errors with a reference line at 1 (equal to the
# naive forecast).
barplot(relative_mse[, "err"])
abline(h = 1, col = 2)
# Rows that do worse than the naive forecast; which() keeps rows with a
# missing ratio from showing up as all-NA rows.
relative_mse[which(relative_mse$err > 1), ]


write.csv(relative_mse, "relative_mse.csv")
write.csv(err_df, "err_df.csv")

# Multiple correlation: for every state, the R-squared of a linear regression
# of its series on the other state series, the regional series of the other
# regions and the national series, all restricted to time_frame.
ili_regional <- env_list$`2014-10`$ili_regional
ili_national <- env_list$`2014-10`$ili_national
state_info <- env_list$`2014-10`$state_info

time_frame <- "2010-10-09/2014-09-27"
state_names_rm_hi_ak <- setdiff(state_names, c("US.HI", "US.AK"))
# Preallocated, named result vector (one slot per state).
multicorr <- setNames(rep(NA_real_, length(state_names)), state_names)
all_region_cols <- seq_len(ncol(ili_regional))
for (each_state in state_names) {
  # The second dot-separated token of the name is looked up in
  # state_info$Abbre to find the region column to leave out.
  region_id_for_state <- state_info[Abbre == strsplit(each_state, "\\.")[[1]][2], Region]
  # setdiff() instead of a negative subscript: when the lookup finds nothing,
  # a negative zero-length index would silently drop every regional column.
  ilireg <- ili_regional[time_frame, setdiff(all_region_cols, region_id_for_state)]
  ilinat <- ili_national[time_frame]
  # Predictor states exclude HI, AK and the response state itself.
  iistate <- ili_state[time_frame, setdiff(state_names_rm_hi_ak, each_state)]
  response <- ili_state[time_frame, each_state]
  multicorr[each_state] <- summary(lm(response ~ cbind(iistate, ilireg, ilinat)))$r.squared
}
write.csv(data.frame(multicorr), "multicorr.csv")
xtable::xtable(data.frame(multicorr), digits = 4)
